'''
@version: Anaconda
@author: 吃口白日梦
@contact: 460958592@qq.com
@software: PyCharm
@file: Sentinel-2.zip convert tif(true color)
@time: 2022/7/11 凌晨 00:03

'''
from osgeo import gdal
import numpy as np
import os
import os
import os
import tqdm
import math
import time
# 受postgis影响，重新设置环境变量
os.environ['CPL_ZIP_ENCODING'] = 'UTF-8'
os.environ['PROJ_LIB'] = r'D:\anaconda3\envs\pytorch_GPU\Lib\site-packages\pyproj\proj_dir\share\proj'
os.environ['GDAL_DATA'] = r'D:\anaconda3\envs\pytorch_GPU\Library\share'


# 重采样
def GdalReprojectImage(path,resampleFactor,num):
    """
    :param path: 文件夾路徑
    :param resampleFactor: 重采樣因子，如等於2則為重采樣為原來的1/2
    :param num:选择的元组
    """

    res = os.listdir(path)
    # 遍历这些文件
    for var in tqdm.tqdm(res):
        # print(f'目前第{start}文件')
        # 拼接路径
        filename = os.path.join(path, var)
        print(filename)
        # 打开栅格文件
        root_ds = gdal.Open(filename)
        # 獲取柵格子數據集
        ds_list = root_ds.GetSubDatasets()
        # print(ds_list)

        # 返回结果是一个list，list中的每个元素是一个tuple，每个tuple中包含了对数据集的路径，元数据等的描述信息
        # tuple中的第一个元素描述的是数据子集的全路径
        #分辨率为20m
        sub = gdal.Open(ds_list[1][0])
        xsize = sub.RasterXSize  # 栅格w
        ysize = sub.RasterYSize  # 栅格h
        # 按块读取栅格(将数据转为ndarray)
        sub_arr = sub.ReadAsArray()
        # 栅格图像的c、h、w
        dimen = np.array(sub_arr).shape
        # 创建.tif文件
        band_count = sub.RasterCount  # 波段数
        # 注意：这里的总波段为11
        bands = sub.GetRasterBand(1)


        out_tif_name = filename.split(".zip")[0] + ".tif"  # # 输出格式
        # 输出宽高
        outWidth = int(xsize * resampleFactor)
        outHeight = int(ysize * resampleFactor)
        # 驱动器
        driver = gdal.GetDriverByName("GTiff")
        out_tif = driver.Create(out_tif_name, outWidth, outHeight, 4, bands.DataType)
        out_tif.SetProjection(sub.GetProjection())  # 设置投影坐标

        geoTransforms = list(sub.GetGeoTransform())
        geoTransforms[1] = geoTransforms[1] / resampleFactor
        geoTransforms[5] = geoTransforms[5] / resampleFactor
        outGeoTransform = tuple(geoTransforms)
        out_tif.SetGeoTransform(outGeoTransform)  # 地理坐标变换

        # # 遍历索引及波段
        for index, band in enumerate(sub_arr):
            # print(band)
            outband = out_tif.GetRasterBand(index + 1)
            outband.SetNoDataValue(-9999)
            nBlockSize = 2048
            i = 0
            j = 0
            b = xsize * ysize
            # 进度条参数
            XBlockcount = math.ceil(xsize / nBlockSize)
            YBlockcount = math.ceil(ysize / nBlockSize)

            try:
                with tqdm.tqdm(total=XBlockcount * YBlockcount, iterable='iterable', desc='第%i波段:' % index, mininterval=10) as pbar:
                    # with tqdm(total=XBlockcount*YBlockcount) as pbar:
                    # print(pbar)
                    while i < ysize:
                        while j < xsize:
                            # 保存分块大小
                            nXBK = nBlockSize
                            nYBK = nBlockSize

                            # 最后不够分块的区域，有多少读取多少
                            if i + nBlockSize > ysize:
                                nYBK = ysize - i
                            if j + nBlockSize > xsize:
                                nXBK = xsize - j

                            # 分块读取影像
                            Image = band.ReadAsArray(j, i, nXBK, nYBK)
                            outband.WriteArray(Image, j, i)

                            j = j + nXBK
                            time.sleep(1)
                            pbar.update(1)
                        j = 0
                        i = i + nYBK
            except KeyboardInterrupt:
                pbar.close()
                raise
            pbar.close()
        out_tif.FlushCache()  # 写入硬盘
        # out_tif.close()  # 关闭tif文件
        out_tif = None  # 关闭tif文件
        print('转换成功')


# 按分辨率输出
def Sentinel2(path):
    # 获取文件夹的所有文件名
    res = os.listdir(path)
    # 遍历这些文件
    for var in tqdm.tqdm(res):
        # print(f'目前第{start}文件')
        # 拼接路径
        filename = os.path.join(path, var)
        print(filename)
        # 打开栅格文件
        root_ds = gdal.Open(filename)
        # 獲取柵格子數據集
        ds_list = root_ds.GetSubDatasets()
        # print(ds_list)
        '''
        元组1（数据路径，10米的波段的信息）4个波段
        [('SENTINEL2_L2A:/vsizip/G:\\新加坡20219-12\\sentinel-2\\S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.zip/S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.SAFE/MTD_MSIL2A.xml:10m:EPSG_32648'
        , 'Bands B2, B3, B4, B8 with 10m resolution, UTM 48N'),
        元组2 （数据路径，20米波段的信息）6个波段
        ('SENTINEL2_L2A:/vsizip/G:\\新加坡20219-12\\sentinel-2\\S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.zip/S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.SAFE/MTD_MSIL2A.xml:20m:EPSG_32648', 
        'Bands B5, B6, B7, B8A, B11, B12, AOT, CLD, SCL, SNW, WVP with 20m resolution, UTM 48N'), 
        元组3（数据路径，60米的波段信息）
        ('SENTINEL2_L2A:/vsizip/G:\\新加坡20219-12\\sentinel-2\\S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.zip/S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.SAFE/MTD_MSIL2A.xml:60m:EPSG_32648',
        'Bands B1, B9, AOT, CLD, SCL, SNW, WVP with 60m resolution, UTM 48N'), 2个波段
        元组4
        ('SENTINEL2_L2A:/vsizip/G:\\新加坡20219-12\\sentinel-2\\S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.zip/S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.SAFE/MTD_MSIL2A.xml:TCI:EPSG_32648', 
        'True color image, UTM 48N')]

        '''

        # 返回结果是一个list，list中的每个元素是一个tuple，每个tuple中包含了对数据集的路径，元数据等的描述信息
        # tuple中的第一个元素描述的是数据子集的全路径
        # 取出列表元组的数据（取出真彩色）
        sub = gdal.Open(ds_list[1][0])
        xsize = sub.RasterXSize  # 栅格w
        ysize = sub.RasterYSize  # 栅格h
        band_count1 = sub.RasterCount
        # sub = gdal.Open(ds_list[1][0])
        # 分辨率为20m
        resampleFactor =2
        Band2_Blue = sub.GetRasterBand(2).ReadAsArray(buf_xsize=xsize*resampleFactor, buf_ysize=ysize*resampleFactor)
        Band3_Green = sub.GetRasterBand(3).ReadAsArray(buf_xsize=xsize*resampleFactor, buf_ysize=ysize*resampleFactor)
        Band4_Red = sub.GetRasterBand(4).ReadAsArray(buf_xsize=xsize*resampleFactor, buf_ysize=ysize*resampleFactor)
        Band8_NIR = sub.GetRasterBand(8).ReadAsArray(buf_xsize=xsize*resampleFactor, buf_ysize=ysize*resampleFactor)
        Band11_SWIR = sub.GetRasterBand(10).ReadAsArray(buf_xsize=xsize*resampleFactor, buf_ysize=ysize*resampleFactor)
        Band12_SWIR = sub.GetRasterBand(11).ReadAsArray(buf_xsize=xsize*resampleFactor, buf_ysize=ysize*resampleFactor)
        # 按块读取栅格(将数据转为ndarray)
        # sub_arr = sub.ReadAsArray()
        band_count = 6  # 波段数
        # print(band_count)
        bands = sub.GetRasterBand(2)


        out_tif_name = filename.split(".zip")[0] + ".tif"  # # 输出格式
        print(out_tif_name)
        # 输出宽高
        outWidth = int(xsize * resampleFactor)
        outHeight = int(ysize * resampleFactor)
        # 驱动器
        driver = gdal.GetDriverByName("GTiff")
        out_tif = driver.Create(out_tif_name, outWidth, outHeight, band_count, bands.DataType)
        out_tif.SetProjection(sub.GetProjection())  # 设置投影坐标
        geoTransforms = list(sub.GetGeoTransform())

        geoTransforms[1] = geoTransforms[1] / resampleFactor
        geoTransforms[5] = geoTransforms[5] / resampleFactor
        outGeoTransform = tuple(geoTransforms)
        out_tif.SetGeoTransform(outGeoTransform)
        out_tif.GetRasterBand(1).WriteArray(Band2_Blue)  # 将每个波段的数据写入内存
        out_tif.GetRasterBand(2).WriteArray(Band3_Green)
        out_tif.GetRasterBand(3).WriteArray(Band4_Red)
        out_tif.GetRasterBand(4).WriteArray(Band8_NIR)
        out_tif.GetRasterBand(5).WriteArray(Band11_SWIR)
        out_tif.GetRasterBand(6).WriteArray(Band12_SWIR)
        out_tif.FlushCache()  # 写入硬盘
        out_tif.BuildOverviews('average', [1, 2, 4, 8, 16, 32])  # 构建金字塔
        gdal.ReprojectImage(
            sub,
            out_tif,
            sub.GetProjection(),
            sub.GetProjection(),
            gdal.gdalconst.GRA_NearestNeighbour,
            0.0, 0.0,
        )
        out_tif = None  # 关闭tif文件


if __name__ == "__main__":
    path = r'G:\哨兵数据库\珠江口\2022\sel(珠江口2022年03齐全)\sentinel-2'
    Sentinel2(path)
    # GdalReprojectImage(path, 2, 1)
    '''
           元组1（数据路径，10米的波段的信息）
           [('SENTINEL2_L2A:/vsizip/G:\\新加坡20219-12\\sentinel-2\\S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.zip/S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.SAFE/MTD_MSIL2A.xml:10m:EPSG_32648'
           , 'Bands B2, B3, B4, B8 with 10m resolution, UTM 48N'),
           元组2 （数据路径，20米波段的信息）
           ('SENTINEL2_L2A:/vsizip/G:\\新加坡20219-12\\sentinel-2\\S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.zip/S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.SAFE/MTD_MSIL2A.xml:20m:EPSG_32648', 
           'Bands B5, B6, B7, B8A, B11, B12, AOT, CLD, SCL, SNW, WVP with 20m resolution, UTM 48N'), 
           元组3（数据路径，60米的波段信息）
           ('SENTINEL2_L2A:/vsizip/G:\\新加坡20219-12\\sentinel-2\\S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.zip/S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.SAFE/MTD_MSIL2A.xml:60m:EPSG_32648',
           'Bands B1, B9, AOT, CLD, SCL, SNW, WVP with 60m resolution, UTM 48N'), 
           元组4
           ('SENTINEL2_L2A:/vsizip/G:\\新加坡20219-12\\sentinel-2\\S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.zip/S2A_MSIL2A_20210206T030911_N0214_R075_T48NVG_20210206T064859.SAFE/MTD_MSIL2A.xml:TCI:EPSG_32648', 
           'True color image, UTM 48N')]

           '''
